import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import f1_score
from scipy import stats
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
pd.set_option('future.no_silent_downcasting', True)
pd.options.mode.chained_assignment = None 

# Load data: few-shot results plotted as "DEBATE Large" (df) and "DEBATE Base"
# (fr), plus the table with the human-coder and Llama columns that feed the
# reference lines (cov).
df, fr, cov = (
    pd.read_csv(csv_path)
    for csv_path in (
        './data/covid_fewshot_large.csv',
        './data/covid_fewshot_base.csv',
        './data/covid_llama_peft_res.csv',
    )
)

# Values for reference lines on the plots

# F1 between the two human coders' columns.
human_coders = round(f1_score(cov['coder2nonc'], cov['coder1nonc']), 2)
print(f"Human benchmark: {human_coders}")

# Zero-shot Llama predictions scored against the 'non_comp' column.
llama_zs = round(f1_score(cov['non_comp'], cov['llama_zs']), 2)
print(f"Llama zero shot: {llama_zs}")

# Take the average over all PEFT runs for the 25-shot reference line.
# NOTE(review): assumes every column from index 7 onward holds PEFT
# predictions (the original comment mentions 30 of them) -- confirm against
# the CSV layout.
peft_columns = cov.columns[7:]
peft_f1s = [f1_score(cov['non_comp'], cov[peft_col]) for peft_col in peft_columns]
llama_peft = float(round(np.mean(peft_f1s), 2))
print(f"Llama 25 shot: {llama_peft}")

supervised = 0.76  # From the Block Jr. et al. 2022 paper



# Plot configuration: base font size and the results column to draw
textsize = 24
metric = 'f1'

# Values of the 'n' column to plot; they double as the x positions of the boxes
x_positions = [0, 10, 25, 50, 100]

# Two colours (one per panel) from seaborn's "colorblind" palette
custom_palette = sns.color_palette("colorblind", n_colors=2)

# Create the figure: two side-by-side subplots sharing the y axis
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 8), sharey=True)

# ---------------------
# Shared drawing routine for both subplots
# ---------------------
def _draw_f1_panel(ax, frame, title, box_color, y_label=None):
    """Draw one box-plot panel of ``metric`` for each value in ``x_positions``.

    Reads the module-level ``metric``, ``x_positions`` and ``textsize``.

    Args:
        ax: Matplotlib axes to draw on.
        frame: DataFrame with an ``n`` column and a column named by ``metric``.
        title: Panel title.
        box_color: Face colour applied to every box.
        y_label: When given, the panel also receives this y-axis label and
            the fixed percentage ticks. Leave as ``None`` for a panel that
            only inherits the shared y axis.

    Returns:
        Tuple ``(samples, boxplot, medians)``: the per-position arrays of
        metric values, the dict returned by ``Axes.boxplot``, and the list
        of per-position medians.
    """
    # One array of metric values per x position, with missing values dropped
    samples = [frame[frame['n'] == pos][metric].dropna().values for pos in x_positions]
    # Place each box at its actual x value rather than at 1..N
    boxplot = ax.boxplot(samples, positions=x_positions, widths=5, patch_artist=True, showfliers=True)

    # Mark every median with a dot; the raised zorder keeps it from being
    # hidden behind the box artists
    medians = [np.median(scores) for scores in samples]
    ax.scatter(x_positions, medians, color='black', s=50, zorder=3)

    for patch in boxplot['boxes']:
        patch.set_facecolor(box_color)

    ax.set_title(title, fontsize=textsize, fontweight='bold')
    ax.set_xlim(-5, 105)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(x_positions, rotation=0)
    # This per-panel label is blanked further down in the script, where a
    # single figure-level label is drawn instead
    ax.set_xlabel('Training Samples', fontsize=textsize)
    if y_label is not None:
        ax.set_ylabel(y_label, fontsize=textsize, fontweight='bold')
        ax.set_yticks([.3, .4, .5, .6, .7, .8, .9])
        ax.set_yticklabels(['30%', '40%', '50%', '60%', '70%', '80%', '90%'])
    ax.tick_params(axis='both', labelsize=textsize - 4)
    ax.grid(False)
    return samples, boxplot, medians


# ---------------------
# Left subplot: DEBATE Base (fr)
# ---------------------
data_fr, bp_fr, medians_fr = _draw_f1_panel(
    axes[0], fr, 'DEBATE Base', custom_palette[1], y_label='F1'
)

# ---------------------
# Right subplot: DEBATE Large (df)
# ---------------------
data_df, bp_df, medians_df = _draw_f1_panel(
    axes[1], df, 'DEBATE Large', custom_palette[0]
)

# ---------------------
# Add common reference lines to both subplots
# ---------------------
# The same four dashed, half-transparent horizontal lines go on every panel.
reference_levels = (human_coders, supervised, llama_peft, llama_zs)
for panel in axes:
    for level in reference_levels:
        panel.axhline(y=level, color='black', linestyle='--', linewidth=2, alpha=0.5)

# Restyle the median line of every box in both panels as a thicker black line.
for median_line in bp_fr['medians'] + bp_df['medians']:
    median_line.set(color='black', linewidth=2)

# ---------------------
# Add annotations to the right subplot
# ---------------------
# Every label is right-aligned 2 data units inside the right x limit and
# offset vertically by 0.02 from the reference value it names.
label_x = axes[1].get_xlim()[1] - 2
reference_labels = [
    (human_coders - .02, 'Human Coders'),
    (supervised - .02, 'Supervised Transformer'),
    (llama_peft - .02, 'Llama 3.1 8B (25-Shot)'),
    (llama_zs + .02, 'Llama 3.1 8B (0-Shot)'),
]
for label_y, label_text in reference_labels:
    axes[1].text(label_x, label_y, label_text, va='center', ha='right', size=textsize)

# Remove individual x-axis labels; one shared label is drawn below instead
for panel in axes:
    panel.set_xlabel("")

# Keep the subplots out of the bottom 3% of the figure, then put the shared
# x-axis label there
plt.tight_layout(rect=[0, 0.03, 1, 1])
fig.text(0.54, 0.02, "Training Samples", ha='center', fontsize=textsize, fontweight='bold')

# Save the figure
plt.savefig('./figures/figure_4.png', dpi=300)